import numpy as np

'''
Use Cholesky decomposition to calculate the cost function for continuous data
'''


# Memo table used by cost(): maps (i, str(J)) -> previously computed cost.
cost_dict = {}


def cost(d, J, i):
    """Calculate the cost for a given parent set J and node i.

    For an empty ``J`` the cost is ``np.var`` of column ``i`` of ``d.data``.
    Otherwise it is ``det(cov[i + J]) / det(cov[J])``, where each covariance
    matrix comes from ``np.cov`` rescaled by ``(ndata - 1) / ndata``.

    Results are memoised in the module-level ``cost_dict`` under the key
    ``(i, str(J))``.  NOTE: the key does not involve ``d``, so a cached value
    is returned even if a different dataset is passed later; clear
    ``cost_dict`` when switching datasets.

    Args:
        d: object exposing ``data`` (2-D array, one column per variable) and
            ``ndata`` (used for the covariance rescaling).
        J: sequence of column indices forming the parent set (may be empty).
        i: column index of the node being scored.

    Returns:
        The cost as a NumPy floating-point scalar.
    """
    key = (i, str(J))
    if key in cost_dict:
        return cost_dict[key]

    ndata = d.ndata
    data = d.data
    data_i = data[:, [i]]

    # ``len(J) == 0`` rather than ``not J`` so that J may be a NumPy array.
    if len(J) == 0:
        value = np.var(data_i)
    else:
        data_J = data[:, J]
        data_iJ = np.concatenate((data_i, data_J), axis=1)
        cov_iJ = np.cov(data_iJ.T) * (ndata - 1) / ndata
        # np.cov collapses to a scalar for a single parent; atleast_2d lets
        # the same determinant call handle one parent and many parents.
        cov_J = np.atleast_2d(np.cov(data_J.T) * (ndata - 1) / ndata)
        value = np.linalg.det(cov_iJ) / np.linalg.det(cov_J)

    cost_dict[key] = value
    return value
    

'''
Evaluate the Lovasz extension function and calculate the subgradient
'''

def g(d, x, i, C_set, lambda_c, regu_Lambda):
    """Evaluate g at ``x`` for node ``i`` and return ``(value, subgradient)``.

    The coordinates of ``x`` are ranked in descending order; that ordering is
    passed to ``quick_subgrad_1`` and ``quick_subgrad_2`` and the two results
    are added to form the subgradient ``s``.  The returned value is
    ``x . s + log(var(d.data[:, i]))``.

    Args:
        d: dataset object; ``d.data`` is indexed here, the remaining
            attributes are consumed by the two helpers.
        x: 1-D NumPy array at which g is evaluated.
        i: column index of the node under consideration.
        C_set, lambda_c: forwarded unchanged to ``quick_subgrad_2``.
        regu_Lambda: forwarded unchanged to ``quick_subgrad_1``.

    Returns:
        Tuple ``(value, s)``.
    """
    order = np.argsort(-x)
    logdet_part = quick_subgrad_1(d, i, order, regu_Lambda)
    cluster_part = quick_subgrad_2(d, i, order, C_set, lambda_c)
    subgrad = logdet_part + cluster_part
    offset = np.log(np.var(d.data[:, i]))
    return np.dot(x, subgrad) + offset, subgrad

def quick_subgrad_1(d, i, sigma_origin, regu_Lambda):  # g with log determinant and regularize
    """Return the subgradient of the log-determinant term plus regulariser.

    ``sigma_origin`` orders the ``n - 1`` variables other than ``i`` (indices
    in the reduced numbering that skips ``i``).  Node ``i`` is placed first,
    the remaining columns follow in that order, and the covariance of the
    permuted data (``np.cov`` rescaled by ``(ndata - 1) / ndata``) is
    Cholesky-factorised.  The entry for the variable at position ``k`` of the
    ordering is the increase in log-determinant of the leading principal
    block when that variable is appended, i.e. ``2 * log(L[k, k])``, plus the
    constant ``regu_Lambda * 2 / ndata``.

    Args:
        d: object exposing ``n``, ``ndata`` and ``data`` (2-D array).
        i: column index of the node under consideration.
        sigma_origin: integer array, a permutation of ``range(n - 1)``.
            It is not modified.
        regu_Lambda: regularisation weight.

    Returns:
        1-D array of length ``n - 1`` (the subgradient only), indexed by the
        reduced variable numbering.
    """
    n = d.n
    ndata = d.ndata
    sigma_origin = np.asarray(sigma_origin, dtype=np.intp)

    # Reduced index -> column of d.data: everything at or above i moves up one.
    columns = np.where(sigma_origin >= i, sigma_origin + 1, sigma_origin)
    sigma_full = np.insert(columns, 0, i)

    # atleast_2d: np.cov collapses to a scalar when only one column is left.
    cov = np.atleast_2d(np.cov(d.data[:, sigma_full].T) * (ndata - 1) / ndata)
    diag_L = np.diag(np.linalg.cholesky(cov))

    # Consecutive prefix log-determinants differ by exactly 2*log(L[k, k]),
    # so the increments are read off the Cholesky diagonal directly instead
    # of re-summing every prefix and subtracting.
    s = np.zeros(n - 1)
    s[sigma_origin] = 2 * np.log(diag_L[1:]) + regu_Lambda * 2 / ndata
    return s

def quick_subgrad_2(d, i, sigma_origin, C_set, lambda_c):  # g with lambda_c
    """Return the subgradient of the cluster-penalty term for node ``i``.

    ``sigma_origin`` orders the ``n - 1`` variables other than ``i`` (reduced
    numbering that skips ``i``).  For every prefix of that ordering the
    penalty is the sum of ``lambda_c[c] * 2 / ndata`` over the clusters
    ``C_set[c]`` that contain ``i`` and at least one variable of the prefix.
    The entry for a variable is the penalty increase caused by appending it.

    Args:
        d: object exposing ``n`` and ``ndata``.
        i: index of the node under consideration.
        sigma_origin: integer NumPy array, a permutation of ``range(n - 1)``.
        C_set: sequence of clusters, each a container of variable indices.
        lambda_c: per-cluster weights, aligned with ``C_set``.

    Returns:
        1-D array of length ``n - 1`` indexed by the reduced numbering.
    """
    n = d.n
    ndata = d.ndata

    # Reduced index -> full variable index (skip over i).
    order = np.where(sigma_origin >= i, sigma_origin + 1, sigma_origin)

    # Only clusters that contain i can ever contribute.
    active = [idx for idx, members in enumerate(C_set) if i in members]

    s = np.zeros(n - 1)
    previous = 0
    for k in range(1, n):
        prefix = order[:k]
        current = 0
        for idx in active:
            members = C_set[idx]
            hit = int(any(var in members for var in prefix))
            current += hit * lambda_c[idx] * 2 / ndata
        # Mapping the full index back to the reduced numbering simply
        # recovers the original entry of sigma_origin.
        s[sigma_origin[k - 1]] = current - previous
        previous = current
    return s

def h(d, x, i):
    """Evaluate h at ``x`` for node ``i`` and return ``(value, subgradient)``.

    The ``n - 1`` variables other than ``i`` are ranked by descending ``x``.
    The covariance of those columns in that order (``np.cov`` rescaled by
    ``(ndata - 1) / ndata``) is Cholesky-factorised, and the entry of ``y``
    for the variable at position ``k`` is the increase in log-determinant of
    the leading principal block when it is appended, i.e.
    ``2 * log(L[k, k])``.

    Args:
        d: object exposing ``n``, ``ndata`` and ``data`` (2-D array).
        x: 1-D NumPy array of length ``n - 1`` (reduced numbering that
            skips ``i``).
        i: column index of the node under consideration.

    Returns:
        Tuple ``(np.dot(x, y), y)`` with ``y`` a 1-D array of length
        ``n - 1``.
    """
    n = d.n
    ndata = d.ndata
    order = np.argsort(-x)

    # Reduced index -> column of d.data: everything at or above i moves up one.
    columns = np.where(order >= i, order + 1, order)

    # atleast_2d: with a single remaining variable (n == 2) np.cov collapses
    # to a scalar, which np.linalg.cholesky would reject.
    cov = np.atleast_2d(np.cov(d.data[:, columns].T) * (ndata - 1) / ndata)
    diag_L = np.diag(np.linalg.cholesky(cov))

    # Consecutive prefix log-determinants differ by exactly 2*log(L[k, k]),
    # so the increments are read off the Cholesky diagonal directly.
    y = np.zeros(n - 1)
    y[order] = 2 * np.log(diag_L)
    return np.dot(x, y), y

# def linear_subgrad(d, i):
#     data = d.data
#     ndata = d.ndata
#     n = d.n
#     data = np.delete(data,i,1)
#     N = list(range(n-1)) # ground set
#     H_N = np.log(np.linalg.det(np.cov(data.T)* (ndata-1)/ndata))

#     subgrad = np.zeros(n-1)
#     for j in list(range(n-1)):
#         Nj = N.copy() # N - {j} for j in J
#         Nj.remove(j)
#         data_Nj = data[:, Nj]
#         cov_Nj = np.cov(data_Nj.T)* (ndata-1)/ndata
#         subgrad[j] = np.log(np.linalg.det(cov_Nj)) - H_N
#     return subgrad

# def h_linear(d, x, i):
#     func_val, func_subgrad = h(d, x, i)
#     sl = linear_subgrad(d, i)
#     return func_val + np.dot(sl.T, x), func_subgrad+sl

# def rho(d, x, i, j):
#     xj = x.copy()
#     xj[j] = 1 # S union j
#     return h_linear(d, xj, i)[0] - h_linear(d, x, i)[0]









